from Global_variable_setting import optimizer_initial_lr, loss_values, metrics_values
from Global_variable_setting import weight_decay, dropout

import tensorflow as tf


def _residual_downsample_block(x, filters):
    """Apply one residual stage that halves the spatial size and sets the channel count.

    Main path: 4x4 stride-2 'same' Conv2D (no bias) -> BatchNormalization -> LeakyReLU.
    Shortcut path: 1x1 stride-2 'valid' Conv2D projection (with bias) so that its
    output matches the main path in both spatial size and channel count.
    The two paths are summed and passed through a final LeakyReLU.

    Args:
        x: Keras tensor for the stage input (channels-last, rank 4 including batch).
        filters: Number of output channels for both paths.

    Returns:
        The Keras tensor produced by the stage.
    """
    shortcut = x

    # Layer creation order matters: it matches the original inline code, so the
    # auto-generated layer names and weight ordering stay identical.
    x = tf.keras.layers.Conv2D(filters, (4, 4), strides=(2, 2), use_bias=False, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)

    # A 1x1 stride-2 'valid' conv yields ceil(n / 2) per spatial dim, the same as
    # the 'same'-padded stride-2 main conv, so the Add below always lines up.
    shortcut = tf.keras.layers.Conv2D(filters, (1, 1), strides=(2, 2), padding='valid')(shortcut)

    x = tf.keras.layers.Add()([x, shortcut])
    return tf.keras.layers.LeakyReLU()(x)


def get_cnn(input_shape=(160, 160, 1), num_outputs=285, block_filters=(32, 64, 128, 256, 512)):
    """Build and compile the residual CNN.

    The network applies one residual down-sampling stage per entry of
    ``block_filters``. With the defaults, that is five stages taking a
    160x160 input down to 5x5 with 512 channels. It then applies
    GlobalAveragePooling2D, Dropout(``dropout``), Dense(512) + LeakyReLU, and a
    final Dense(``num_outputs``) layer with no activation (linear outputs).

    The defaults reproduce the original hard-coded architecture exactly.

    Args:
        input_shape: Shape of one input sample, excluding the batch dimension.
        num_outputs: Width of the final linear Dense layer.
        block_filters: Channel count for each successive residual stage.

    Returns:
        A ``tf.keras.Model`` compiled with SGD (learning rate
        ``optimizer_initial_lr``), loss ``loss_values`` and metrics
        ``[metrics_values]``. All three names come from ``Global_variable_setting``.
    """
    # NOTE(review): `weight_decay` is imported at module level but not applied to
    # any layer or the optimizer here; confirm whether regularization was intended.
    inputs = tf.keras.layers.Input(shape=input_shape)

    x = inputs
    for filters in block_filters:
        x = _residual_downsample_block(x, filters)

    # final processing
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(dropout)(x)
    x = tf.keras.layers.Dense(512)(x)
    x = tf.keras.layers.LeakyReLU()(x)
    outputs = tf.keras.layers.Dense(num_outputs)(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    # NOTE(review): `metrics_values` is wrapped in a list here; if the config
    # already defines it as a list, this becomes a nested list — confirm.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=optimizer_initial_lr),
        loss=loss_values,
        metrics=[metrics_values]
    )
    return model



# def get_cnn():
#     model = tf.keras.Sequential([
#         tf.keras.layers.Conv2D(32, (7, 7), input_shape=(160, 160, 1)),
#         tf.keras.layers.LeakyReLU(),
#         tf.keras.layers.MaxPooling2D((2, 2), strides=2),
#
#         tf.keras.layers.Conv2D(64, (6, 6)),
#         tf.keras.layers.LeakyReLU(),
#         tf.keras.layers.MaxPooling2D((2, 2), strides=2),
#
#         tf.keras.layers.Conv2D(128, (5, 5)),
#         tf.keras.layers.LeakyReLU(),
#         tf.keras.layers.MaxPooling2D((2, 2), strides=2),
#
#         tf.keras.layers.Conv2D(256, (3, 3)),
#         tf.keras.layers.LeakyReLU(),
#         tf.keras.layers.MaxPooling2D((2, 2), strides=2),
#
#         tf.keras.layers.Conv2D(512, (2, 2)),
#         tf.keras.layers.LeakyReLU(),
#         tf.keras.layers.MaxPooling2D((2, 2), strides=2),
#
#         tf.keras.layers.Flatten(),
#         tf.keras.layers.Dropout(dropout),
#         tf.keras.layers.Dense(2048),
#         tf.keras.layers.LeakyReLU(),
#         # tf.keras.layers.Dropout(dropout),
#         tf.keras.layers.Dense(1024),
#         tf.keras.layers.LeakyReLU(),
#         tf.keras.layers.Dense(512),
#         tf.keras.layers.LeakyReLU(),
#         tf.keras.layers.Dense(285),
#     ])
#     model.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=optimizer_initial_lr),
#         loss=loss_values,
#         metrics=[metrics_values]
#     )
#     return model










# def get_vision_transformer():
#     model = ViT(image_size, patch_size, num_actuators=17*17, dim=projection_dim,
#                 depth=transformer_layers, heads=num_heads, mlp_head_units=mlp_head_units,
#                 transformer_hidden_units=transformer_hidden_units, dim_head=projection_dim,
#                 dropout=dropout, emb_dropout=emb_dropout)
#     model.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=optimizer_initial_lr),
#         loss=loss_values,
#         metrics=[metrics_values]
#     )
#     return model